{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('../')\n",
    "#  Torch imports\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import torch.nn.functional as F\n",
    "from torch.utils.tensorboard import SummaryWriter\n",
    "import torch.backends.cudnn as cudnn\n",
    "cudnn.benchmark = True\n",
    "\n",
    "# Python imports\n",
    "import numpy as np\n",
    "import tqdm\n",
    "import torchvision.models as tmodels\n",
    "from tqdm import tqdm\n",
    "import os\n",
    "from os.path import join as ospj\n",
    "import itertools\n",
    "import glob\n",
    "import random\n",
    "\n",
    "#Local imports\n",
    "from data import dataset as dset\n",
    "from models.common import Evaluator\n",
    "from models.image_extractor import get_image_extractor\n",
    "from models.manifold_methods import RedWine, LabelEmbedPlus, AttributeOperator\n",
    "from models.modular_methods import GatedGeneralNN\n",
    "from models.symnet import Symnet\n",
    "from utils.utils import save_args, UnNormalizer, load_args\n",
    "from utils.config_model import configure_model\n",
    "from flags import parser\n",
    "from PIL import Image\n",
    "import matplotlib.pyplot as plt\n",
    "import importlib\n",
    "\n",
    "# Use the GPU when torch reports one as available, otherwise stay on the CPU.\n",
    "if torch.cuda.is_available():\n",
    "    device = 'cuda'\n",
    "else:\n",
    "    device = 'cpu'\n",
    "\n",
    "# parse_known_args() yields two values; the second presumably holds the arguments the\n",
    "# parser did not recognise (assumes `parser` is a standard argparse parser -- TODO confirm).\n",
    "args, unknown = parser.parse_known_args()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Run one of the two cells below to select the dataset to evaluate, then move on to the next section"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Config file of the mitstates run; the checkpoint is read from the same directory.\n",
    "best_mit = '../logs/graphembed/mitstates/base/mit.yml'\n",
    "load_args(best_mit, args)\n",
    "# The notebook addresses project paths through '../', so prefix the configured one.\n",
    "args.graph_init = '../' + args.graph_init\n",
    "# Swap the config file name for the checkpoint file name.\n",
    "args.load = os.path.dirname(best_mit) + '/' + 'ckpt_best_auc.t7'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Config file of the utzappos run; the checkpoint is read from the same directory.\n",
    "best_ut = '../logs/graphembed/utzappos/base/utzappos.yml'\n",
    "load_args(best_ut, args)\n",
    "# The notebook addresses project paths through '../', so prefix the configured one.\n",
    "args.graph_init = '../' + args.graph_init\n",
    "# Swap the config file name for the checkpoint file name.\n",
    "args.load = os.path.dirname(best_ut) + '/' + 'ckpt_best_auc.t7'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Loading arguments and dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE: the '../' prefix is added on every run of this cell, so reload the args\n",
    "# (one of the dataset cells above) before running it a second time.\n",
    "args.data_dir = '../' + args.data_dir\n",
    "args.test_set = 'test'\n",
    "\n",
    "dataset_kwargs = dict(\n",
    "    root=args.data_dir,\n",
    "    phase=args.test_set,\n",
    "    split=args.splitname,\n",
    "    model=args.image_extractor,\n",
    "    subset=args.subset,\n",
    "    # Later cells read the last element of each batch as the image file name.\n",
    "    return_images=True,\n",
    "    update_features=args.update_features,\n",
    "    clean_only=args.clean_only,\n",
    ")\n",
    "testset = dset.CompositionDataset(**dataset_kwargs)\n",
    "\n",
    "testloader = torch.utils.data.DataLoader(\n",
    "    testset,\n",
    "    shuffle=True,\n",
    "    batch_size=args.test_batch_size,\n",
    "    num_workers=args.workers,\n",
    ")\n",
    "\n",
    "print('Objs ', len(testset.objs), ' Attrs ', len(testset.attrs))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# configure_model returns three objects; `image_extractor` may be falsy (later cells\n",
    "# guard its use with `if image_extractor`) and `optimizer` is not used in this notebook.\n",
    "image_extractor, model, optimizer = configure_model(args, testset)\n",
    "# The evaluator is built from the test set and the model and does the scoring further down.\n",
    "evaluator = Evaluator(testset, model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Restore the weights from the checkpoint chosen in one of the dataset cells above.\n",
    "if args.load is not None:\n",
    "    # SECURITY: torch.load unpickles the file -- only open checkpoints you trust.\n",
    "    # map_location lets a checkpoint saved on a GPU be opened on a CPU-only machine.\n",
    "    checkpoint = torch.load(args.load, map_location=device)\n",
    "    if image_extractor:\n",
    "        # Test for the key explicitly instead of using a bare except, so genuine\n",
    "        # load_state_dict failures (e.g. mismatched weights) are no longer hidden.\n",
    "        if 'image_extractor' in checkpoint:\n",
    "            image_extractor.load_state_dict(checkpoint['image_extractor'])\n",
    "        else:\n",
    "            print('No Image extractor in checkpoint')\n",
    "        # Switch to inference mode whether or not weights were restored.\n",
    "        image_extractor.eval()\n",
    "    model.load_state_dict(checkpoint['net'])\n",
    "    model.eval()\n",
    "    print('Loaded model from ', args.load)\n",
    "    print('Best AUC: ', checkpoint['AUC'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def print_results(scores, exp):\n",
    "    \"\"\"Print the ground-truth pair and the top-k predicted pairs of one sample.\n",
    "\n",
    "    Besides its arguments, the function reads the notebook globals `evaluator`,\n",
    "    `data`, `idx` and `topk`; they must be defined before it is called.\n",
    "\n",
    "    Args:\n",
    "        scores: container indexed by `exp`. The selected entry is read as\n",
    "            entry[0][idx, k] for attribute indices and entry[1][idx, k] for\n",
    "            object indices.\n",
    "        exp: key of the setting to report; it is also printed as a header line.\n",
    "    \"\"\"\n",
    "    print(exp)\n",
    "    result = scores[exp]\n",
    "    dataset = evaluator.dset\n",
    "    ranks = range(topk)\n",
    "\n",
    "    # Translate the predicted indices of sample `idx` into names.\n",
    "    attr_names = [dataset.attrs[result[0][idx, rank]] for rank in ranks]\n",
    "    obj_names = [dataset.objs[result[1][idx, rank]] for rank in ranks]\n",
    "\n",
    "    # Ground-truth attribute and object of the same sample.\n",
    "    attr_gt = dataset.attrs[data[1][idx]]\n",
    "    obj_gt = dataset.objs[data[2][idx]]\n",
    "    print(f'Ground truth: {attr_gt} {obj_gt}')\n",
    "\n",
    "    prediction = ''.join(\n",
    "        attr_name + ' ' + obj_name + '| '\n",
    "        for attr_name, obj_name in zip(attr_names, obj_names)\n",
    "    )\n",
    "    print('Predictions: ', prediction)\n",
    "    print('__'*50)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### An example of predictions\n",
    "closed -> Biased for unseen classes\n",
    "\n",
    "unbiased -> Biased against unseen classes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Take one batch: its last element is kept as `images` (used as image file names in\n",
    "# the next cell); every other element is moved to the compute device.\n",
    "data = next(iter(testloader))\n",
    "images = data[-1]\n",
    "data = [d.to(device) for d in data[:-1]]\n",
    "# Inference only: no_grad avoids building an autograd graph for the forward pass.\n",
    "with torch.no_grad():\n",
    "    if image_extractor:\n",
    "        data[0] = image_extractor(data[0])\n",
    "    _, predictions = model(data)\n",
    "# Bring the batch back to the CPU; the next cell and print_results index into it.\n",
    "data = [d.to('cpu') for d in data]\n",
    "topk = 5\n",
    "results = evaluator.score_model(predictions, data[2], bias = 1000, topk=topk)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Show every sample of the example batch whose ground-truth pair is not flagged in\n",
    "# evaluator.seen_mask, together with its predictions.\n",
    "# `idx` must keep this name: print_results reads it as a notebook global.\n",
    "for idx in range(len(images)):\n",
    "    seen = bool(evaluator.seen_mask[data[3][idx]])\n",
    "    if seen:\n",
    "        continue\n",
    "\n",
    "    # The context manager closes the image file once the figure has been drawn.\n",
    "    with Image.open(ospj(args.data_dir, 'images', images[idx])) as image:\n",
    "        fig = plt.figure(dpi=300)\n",
    "        plt.imshow(image)\n",
    "        plt.axis('off')\n",
    "        plt.show()\n",
    "    # Release the figure so figures do not pile up in memory across iterations.\n",
    "    plt.close(fig)\n",
    "\n",
    "    print(f'GT pair seen: {seen}')\n",
    "    print_results(results, 'closed')\n",
    "    print_results(results, 'unbiased_closed')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Run Evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the model over the whole test loader and collect predictions and ground truth.\n",
    "model.eval()\n",
    "if image_extractor:\n",
    "    # Keep the feature extractor in inference mode as well.\n",
    "    image_extractor.eval()\n",
    "args.bias = 1e3\n",
    "accuracies, all_attr_gt, all_obj_gt, all_pair_gt, all_pred = [], [], [], [], []\n",
    "\n",
    "# no_grad: without it every stored prediction would keep its autograd graph alive\n",
    "# for the whole pass over the test set.\n",
    "# The loop variable is named `batch` (not `data`/`idx`) so that the example batch\n",
    "# used by print_results and the example-display cell is left untouched.\n",
    "with torch.no_grad():\n",
    "    for batch in tqdm(testloader, total=len(testloader), desc = 'Testing'):\n",
    "        # Drop the last element (not a tensor input) and move the rest to the device.\n",
    "        batch = [d.to(device) for d in batch[:-1]]\n",
    "        if image_extractor:\n",
    "            batch[0] = image_extractor(batch[0])\n",
    "\n",
    "        _, predictions = model(batch) # todo: Unify outputs across models\n",
    "\n",
    "        attr_truth, obj_truth, pair_truth = batch[1], batch[2], batch[3]\n",
    "        all_pred.append(predictions)\n",
    "        all_attr_gt.append(attr_truth)\n",
    "        all_obj_gt.append(obj_truth)\n",
    "        all_pair_gt.append(pair_truth)\n",
    "\n",
    "all_attr_gt, all_obj_gt, all_pair_gt = torch.cat(all_attr_gt), torch.cat(all_obj_gt), torch.cat(all_pair_gt)\n",
    "\n",
    "# Gather values as dict of (attr, obj) as key and list of predictions as values\n",
    "all_pred_dict = {\n",
    "    key: torch.cat([batch_pred[key] for batch_pred in all_pred])\n",
    "    for key in all_pred[0].keys()\n",
    "}\n",
    "\n",
    "# Calculate best unseen accuracy\n",
    "attr_truth, obj_truth = all_attr_gt.to('cpu'), all_obj_gt.to('cpu')\n",
    "pairs = list(\n",
    "    zip(list(attr_truth.numpy()), list(obj_truth.numpy())))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of ranked predictions requested from the evaluator.\n",
    "topk = 1\n",
    "\n",
    "our_results = evaluator.score_model(\n",
    "    all_pred_dict,\n",
    "    all_obj_gt,\n",
    "    bias=1e3,\n",
    "    topk=topk,\n",
    ")\n",
    "stats = evaluator.evaluate_predictions(\n",
    "    our_results,\n",
    "    all_attr_gt,\n",
    "    all_obj_gt,\n",
    "    all_pair_gt,\n",
    "    all_pred_dict,\n",
    "    topk=topk,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One line per entry of `stats`: the key followed by its value.\n",
    "for metric_name, metric_value in stats.items():\n",
    "    print(metric_name, metric_value)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
